# -*- coding: utf-8 -*-
# !/usr/bin/env python
"""
-------------------------------------------------
   File Name:      model
   Description:    BERT-based text classifier (BertClassifier)
   Author:         lth
   date:           2023/1/30
-------------------------------------------------
   Change Activity:
                   2023/1/30 10:28: create this script
-------------------------------------------------
"""
__author__ = 'lth'

from torch import nn
from transformers import BertModel


class BertClassifier(nn.Module):
    """Text classifier: a pretrained BERT encoder, dropout, and a linear head.

    The head is applied to the second element of the tuple BERT returns when
    called with ``return_dict=False`` (the pooled output).

    Args:
        dropout: Dropout probability applied to the pooled output before the
            linear layer.
        type: Not used anywhere in this class. It is kept only so existing
            callers that pass it (positionally or as ``type=...``) still work.
            Note that it shadows the ``type`` builtin inside ``__init__``.
        num_labels: Number of output columns of the linear head. Defaults to
            5, the value that was previously hard-coded.
        pretrained_model_name: Name or path passed to
            ``BertModel.from_pretrained``. Defaults to the previously
            hard-coded ``"bert-base-cased"``.
        relu_output: If True (the default, matching previous behaviour), a
            ReLU is applied to the linear layer's output. Set to False to get
            raw, possibly negative, scores.
    """

    def __init__(self, dropout=0.5, type="train", *, num_labels=5,
                 pretrained_model_name="bert-base-cased", relu_output=True):
        super().__init__()

        # NOTE(review): `mirror="tuna"` picked a download mirror in older
        # transformers releases; newer releases appear to deprecate/ignore it.
        # Confirm against the installed version before relying on it.
        self.bert = BertModel.from_pretrained(pretrained_model_name, mirror="tuna")
        self.dropout = nn.Dropout(dropout)
        # Take the encoder width from its config instead of hard-coding 768,
        # so checkpoints with a different hidden size also work.
        self.linear = nn.Linear(self.bert.config.hidden_size, num_labels)
        self.relu = nn.ReLU()
        self.relu_output = relu_output

    def forward(self, x, mask):
        """Classify a batch of token sequences.

        Args:
            x: Token ids, passed to BERT as ``input_ids``.
            mask: Passed to BERT as ``attention_mask``.

        Returns:
            Output of the linear head (one column per label), passed through
            ReLU when ``relu_output`` is True.
        """
        _, pooled_output = self.bert(input_ids=x, attention_mask=mask, return_dict=False)
        logits = self.linear(self.dropout(pooled_output))
        # getattr keeps instances pickled before `relu_output` existed
        # working; they behave as before, i.e. with ReLU.
        if getattr(self, "relu_output", True):
            # ReLU clamps negative scores to 0 and zeroes their gradient. If
            # these outputs feed a softmax/cross-entropy loss, prefer
            # relu_output=False for new training runs.
            return self.relu(logits)
        return logits
